import abc

import numpy as np
from metaworld.policies.policy import Policy
from metaworld.policies.action import Action

from diffgro.environments.metaworld.policies.skill import *


class HeuristicPolicy(Policy):
    """Scripted policy that executes a fixed sequence of skills, one at a time.

    Subclasses provide two hooks:

    * ``_parse_obs(obs)`` -- split the flat observation into a dict of named
      fields; the dict must contain ``"hand_pos"``.
    * ``_initialize_skill(o_d)`` -- assign a list of skill objects to
      ``self.skills`` using the first parsed observation of the episode.

    The skills are only accessed through ``is_done(hand_pos)``,
    ``get_delta_pos(hand_pos)``, ``get_grab_effort()`` and
    ``get_name(hand_pos)``.
    """

    def __init__(self):
        super().__init__()
        self._reset_state()

    def _reset_state(self):
        """Clear the skill plan and progress so the plan is rebuilt on next use."""
        self.phase = 0  # index into self.skills of the skill being executed
        self.skills = None  # built lazily by _initialize_skill on the first step
        self.is_done = False  # set once the last skill reports completion

    @abc.abstractmethod
    def _initialize_skill(self, o_d):
        """Assign this episode's list of skills to ``self.skills``.

        Args:
            o_d: parsed observation dict returned by ``_parse_obs``.
        """
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def _parse_obs(obs):
        """Split the flat observation ``obs`` into a dict of named fields."""
        raise NotImplementedError

    def get_action(self, obs, return_skill=False):
        """Compute the action for one environment step.

        Args:
            obs: flat observation vector from the environment.
            return_skill: when True, also return the active skill's name.

        Returns:
            The action array (indices 0-2: delta_pos, index 3: grab_effort),
            or ``(action_array, skill_name)`` when ``return_skill`` is True.
        """
        o_d = self._parse_obs(obs)

        skill = self._get_skill(o_d)
        action = self._decode_skill(skill, o_d)

        if return_skill:
            return action.array, skill.get_name(o_d["hand_pos"])
        return action.array  # [delta_pos, grab_effort]

    def _get_skill(self, o_d):
        """Return the skill to execute this step, advancing the phase if needed.

        The plan is built on the first call. When the current skill reports it
        is done, ``self.phase`` advances, but the skill that was just queried is
        still the one returned for this step; the next skill takes over on the
        following call. After the final skill finishes, ``self.is_done`` is set
        and the phase stays pinned on the last skill.
        """
        if self.skills is None:
            self._initialize_skill(o_d)

        curr_skill = self.skills[self.phase]
        if curr_skill.is_done(o_d["hand_pos"]):
            self.phase += 1
            if self.phase >= len(self.skills):
                self.is_done = True
                self.phase = len(self.skills) - 1

        return curr_skill

    def _decode_skill(self, skill, o_d):
        """Translate ``skill``'s command at the current hand position into an Action."""
        # Layout: indices 0-2 hold "delta_pos", index 3 holds "grab_effort".
        action = Action({"delta_pos": np.arange(3), "grab_effort": 3})
        action["delta_pos"] = skill.get_delta_pos(o_d["hand_pos"])
        action["grab_effort"] = skill.get_grab_effort()

        return action

    def reset(self):
        """Forget the current plan so the next ``get_action`` rebuilds it."""
        self._reset_state()


############# Button Press V2 Policy #############


class SkillButtonPressV2Policy(HeuristicPolicy):
    """Align with the button in x and z, approach along y, then push."""

    def _initialize_skill(self, o_d):
        # Target point is the button position shifted by -0.07 in z.
        press_point = o_d["button_pos"] + np.array([0.0, 0.0, -0.07])
        px, py, pz = press_point
        push_goal = press_point + np.array([0.0, 0.2, 0])

        plan = [MoveXSkill(px), MoveZSkill(pz)]
        plan.append(MoveYSkill(py + 0.02, atol=0.06, p=3.0))
        plan.append(PushSkill(push_goal))
        self.skills = plan

    @staticmethod
    def _parse_obs(obs):
        """Name the fields of the button-press observation vector."""
        return dict(
            hand_pos=obs[:3],
            hand_closed=obs[3],
            button_pos=obs[4:7],
            unused_info=obs[7:],
        )


class SkillButtonPressVariantV2Policy(SkillButtonPressV2Policy):
    """Button press whose push gain is chosen by ``variant["goal_resistance"]``."""

    def __init__(self, variant):
        super().__init__()
        # Variant settings; "goal_resistance" indexes a 3-entry gain list.
        self.variant = variant

    def _initialize_skill(self, o_d):
        press_point = o_d["button_pos"] + np.array([0.0, 0.0, -0.07])
        px, py, pz = press_point

        push_gain = [2, 3.5, 5][self.variant["goal_resistance"]]
        push_goal = press_point + np.array([0.0, 0.2, 0])

        plan = [MoveXSkill(px), MoveZSkill(pz)]
        plan.append(MoveYSkill(py + 0.02, atol=0.06))
        plan.append(PushSkill(push_goal, p=push_gain))
        self.skills = plan


############# Door Open V2 Policy #############


class SkillDoorOpenV2Policy(HeuristicPolicy):
    """Rise, line up with the door target, descend, then pull along -x/-y."""

    def _initialize_skill(self, o_d):
        effort = 1
        handle = o_d["door_pos"] + np.array([-0.05, 0, 0])
        hx, hy, hz = handle
        pull_to = handle + np.array([-0.4, -0.2, 0])
        self.skills = [
            MoveZSkill(hz + 0.12, grab_effort=effort),
            MoveXSkill(hx + 0.04, grab_effort=effort),
            MoveYSkill(hy + 0.02, grab_effort=effort),
            MoveZSkill(hz + 0.04, grab_effort=effort),
            PullSkill(pull_to, grab_effort=effort),
        ]

    @staticmethod
    def _parse_obs(obs):
        """Name the fields of the door observation vector."""
        return dict(
            hand_pos=obs[:3],
            gripper=obs[3],
            door_pos=obs[4:7],
            unused_info=obs[7:],
        )


class SkillDoorOpenVariantV2Policy(SkillDoorOpenV2Policy):
    """Door-open policy whose pull gain follows ``variant["goal_resistance"]``.

    Unlike the base policy, the descent stops at target z + 0.02.
    """

    def __init__(self, variant):
        super().__init__()
        # Variant settings; "goal_resistance" indexes a 3-entry gain list.
        self.variant = variant

    def _initialize_skill(self, o_d):
        effort = 1
        handle = o_d["door_pos"] + np.array([-0.05, 0, 0])
        hx, hy, hz = handle

        gain = [2, 3.5, 5][self.variant["goal_resistance"]]
        pull_to = handle + np.array([-0.4, -0.2, 0])
        self.skills = [
            MoveZSkill(hz + 0.12, grab_effort=effort),
            MoveXSkill(hx + 0.04, grab_effort=effort),
            MoveYSkill(hy + 0.02, grab_effort=effort),
            MoveZSkill(hz + 0.02, grab_effort=effort),
            PullSkill(pull_to, p=gain, grab_effort=effort),
        ]


############# Drawer Close V2 Policy #############


class SkillDrawerCloseV2Policy(HeuristicPolicy):
    """Rise, position in front of the drawer, descend, then push along +y."""

    def _initialize_skill(self, o_d):
        effort = 1
        drawer = o_d["drwr_pos"] + np.array([0.0, 0.0, -0.02])
        dx, dy, dz = drawer
        push_to = drawer + np.array([0, 0.2, 0])
        self.skills = [
            MoveZSkill(dz + 0.23, grab_effort=effort),
            MoveXSkill(dx, grab_effort=effort),
            MoveYSkill(dy - 0.075, grab_effort=effort),
            MoveZSkill(dz, grab_effort=effort),
            PushSkill(push_to, grab_effort=effort),
        ]

    @staticmethod
    def _parse_obs(obs):
        """Name the fields of the drawer observation vector."""
        return dict(
            hand_pos=obs[:3],
            gripper=obs[3],
            drwr_pos=obs[4:7],
            unused_info=obs[7:],
        )


class SkillDrawerCloseVariantV2Policy(SkillDrawerCloseV2Policy):
    """Drawer-close policy whose push gain follows ``variant["goal_resistance"]``.

    Adds an extra y move to ``y - 0.025`` before the push.
    """

    def __init__(self, variant):
        super().__init__()
        # Variant settings; "goal_resistance" indexes a 3-entry gain list.
        self.variant = variant

    def _initialize_skill(self, o_d):
        effort = 1
        drawer = o_d["drwr_pos"] + np.array([0.0, 0.0, -0.02])
        dx, dy, dz = drawer

        gain = [2, 3.5, 5][self.variant["goal_resistance"]]
        push_to = drawer + np.array([0, 0.2, 0])
        self.skills = [
            MoveZSkill(dz + 0.23, grab_effort=effort),
            MoveXSkill(dx, grab_effort=effort),
            MoveYSkill(dy - 0.075, grab_effort=effort),
            MoveZSkill(dz, grab_effort=effort),
            MoveYSkill(dy - 0.025, grab_effort=effort),
            PushSkill(push_to, p=gain, grab_effort=effort),
        ]


############# Drawer Open V2 Policy #############


class SkillDrawerOpenV2Policy(HeuristicPolicy):
    """Line up with the drawer, descend, then pull along -y (grab_effort -1.0)."""

    def _initialize_skill(self, o_d):
        effort = -1.0
        drawer = o_d["drwr_pos"] + np.array([0.0, 0.0, -0.02])
        dx, dy, dz = drawer
        pull_to = drawer + np.array([0, -0.2, 0])
        self.skills = [
            MoveXSkill(dx, grab_effort=effort),
            MoveYSkill(dy + 0.02, grab_effort=effort),
            MoveZSkill(dz + 0.04, grab_effort=effort),
            PullSkill(pull_to, grab_effort=effort),
        ]

    @staticmethod
    def _parse_obs(obs):
        """Name the fields of the drawer observation vector."""
        return dict(
            hand_pos=obs[:3],
            gripper=obs[3],
            drwr_pos=obs[4:7],
            unused_info=obs[7:],
        )


class SkillDrawerOpenVariantV2Policy(SkillDrawerOpenV2Policy):
    """Drawer-open policy that also stores a variant config.

    The skill plan is inherited unchanged from ``SkillDrawerOpenV2Policy``.
    """

    def __init__(self, variant):
        super().__init__()
        # Not read by any method of this class; kept for parity with the
        # other *Variant* policies.
        self.variant = variant


############# Peg Insert Side V2 Policy #############


class SkillPegInsertSideV2Policy(HeuristicPolicy):
    """Grab the peg, raise it to z = 0.16, then move to the goal's y and x = -0.35."""

    def _initialize_skill(self, o_d):
        # Approach the peg with grab_effort -1, then grab it.
        peg_x, peg_y, peg_z = o_d["peg_pos"]
        grasp = [
            MoveXSkill(peg_x, grab_effort=-1),
            MoveYSkill(peg_y, grab_effort=-1, atol=0.02),
            MoveZSkill(peg_z + 0.025, grab_effort=-1),
            GrabSkill(grab_effort=1),
        ]

        # Carry the peg: lift, align with the goal's y, then slide in x.
        ins_x, ins_y, ins_z = np.array([-0.35, o_d["goal_pos"][1], 0.16])
        insert = [
            MoveZSkill(ins_z, grab_effort=1),
            MoveYSkill(ins_y, grab_effort=1),
            MoveXSkill(ins_x, grab_effort=1),
        ]

        self.skills = grasp + insert

    @staticmethod
    def _parse_obs(obs):
        """Name the fields of the peg-insert observation vector."""
        return dict(
            hand_pos=obs[:3],
            gripper_distance_apart=obs[3],
            peg_pos=obs[4:7],
            peg_rot=obs[7:11],
            goal_pos=obs[-3:],
            unused_info_curr_obs=obs[11:18],
            _prev_obs=obs[18:36],
        )


class SkillPegInsertSideVariantV2Policy(SkillPegInsertSideV2Policy):
    """Peg-insert variant that carries the peg 0.01 higher (z = 0.16 + 0.01)."""

    def __init__(self, variant):
        super().__init__()
        # Stored variant settings; not read by this class's plan.
        self.variant = variant

    def _initialize_skill(self, o_d):
        peg_x, peg_y, peg_z = o_d["peg_pos"]
        grasp = [
            MoveXSkill(peg_x, grab_effort=-1),
            MoveYSkill(peg_y, grab_effort=-1, atol=0.02),
            MoveZSkill(peg_z + 0.025, grab_effort=-1),
            GrabSkill(grab_effort=1),
        ]

        ins_x, ins_y, ins_z = np.array([-0.35, o_d["goal_pos"][1], 0.16])
        insert = [
            MoveZSkill(ins_z + 0.01, grab_effort=1),
            MoveYSkill(ins_y, grab_effort=1),
            MoveXSkill(ins_x, grab_effort=1),
        ]

        self.skills = grasp + insert


############# Push V2 Policy #############


class SkillPushV2Policy(HeuristicPolicy):
    """Grab the puck, then move it to the goal in x followed by y."""

    def _initialize_skill(self, o_d):
        puck_x, puck_y, puck_z = o_d["puck_pos"] + np.array([-0.005, 0, 0])
        goal_x, goal_y, _ = o_d["goal_pos"]
        self.skills = [
            # Reach and grab (grab_effort -1 during approach, 1 to grab).
            MoveXSkill(puck_x, grab_effort=-1),
            MoveYSkill(puck_y, grab_effort=-1),
            MoveZSkill(puck_z + 0.03, grab_effort=-1),
            GrabSkill(grab_effort=1),
            # Transport to the goal in the x-y plane.
            MoveXSkill(goal_x, grab_effort=1),
            MoveYSkill(goal_y, grab_effort=1),
        ]

    @staticmethod
    def _parse_obs(obs):
        """Name the fields of the puck observation vector."""
        return dict(
            hand_pos=obs[:3],
            gripper_distance_apart=obs[3],
            puck_pos=obs[4:7],
            puck_rot=obs[7:11],
            goal_pos=obs[-3:],
            unused_info_curr_obs=obs[11:18],
            _prev_obs=obs[18:36],
        )


class SkillPushVariantV2Policy(SkillPushV2Policy):
    """Push policy that detours in x when ``variant["wall_existence"]`` is truthy."""

    def __init__(self, variant):
        super().__init__()
        # Variant settings; "wall_existence" selects the transport path.
        self.variant = variant

    def _initialize_skill(self, o_d):
        puck_x, puck_y, puck_z = o_d["puck_pos"] + np.array([-0.005, 0, 0])
        grasp = [
            MoveXSkill(puck_x, grab_effort=-1, atol=0.001),
            MoveYSkill(puck_y, grab_effort=-1),
            MoveZSkill(puck_z + 0.03, grab_effort=-1),
            GrabSkill(grab_effort=1),
        ]

        goal_x, goal_y, _ = o_d["goal_pos"]
        if not self.variant["wall_existence"]:
            transport = [
                MoveXSkill(goal_x, grab_effort=1),
                MoveYSkill(goal_y, grab_effort=1),
            ]
        else:
            # Stop 0.1 short in x, cross in y, then finish the x move.
            transport = [
                MoveXSkill(goal_x - 0.1, grab_effort=1),
                MoveYSkill(goal_y, grab_effort=1),
                MoveXSkill(goal_x, grab_effort=1),
            ]

        self.skills = grasp + transport


############# Reach V2 Policy #############


class SkillReachV2Policy(HeuristicPolicy):
    """Move the hand to ``goal_pos`` one axis at a time: z, then x, then y."""

    def _initialize_skill(self, o_d):
        gx, gy, gz = o_d["goal_pos"]
        self.skills = [MoveZSkill(gz), MoveXSkill(gx), MoveYSkill(gy)]

    @staticmethod
    def _parse_obs(obs):
        """Name the fields of the reach observation vector."""
        return dict(
            hand_pos=obs[:3],
            unused_1=obs[3],
            puck_pos=obs[4:7],
            unused_2=obs[7:-3],
            goal_pos=obs[-3:],
        )


class SkillReachVariantV2Policy(SkillReachV2Policy):
    """Reach policy that goes 0.1 above the goal first when a wall is present."""

    def __init__(self, variant):
        super().__init__()
        # Variant settings; "wall_existence" selects the approach path.
        self.variant = variant

    def _initialize_skill(self, o_d):
        gx, gy, gz = o_d["goal_pos"]

        if self.variant["wall_existence"]:
            # Travel at goal z + 0.1, then descend onto the goal.
            self.skills = [
                MoveZSkill(gz + 0.1),
                MoveXSkill(gx),
                MoveYSkill(gy),
                MoveZSkill(gz),
            ]
            return

        self.skills = [
            MoveZSkill(gz + 0.01),
            MoveXSkill(gx),
            MoveYSkill(gy),
        ]


############# Window Open V2 Policy #############


class SkillWindowOpenV2Policy(HeuristicPolicy):
    """Line up with the window target in x and y, then push it along +x."""

    def _initialize_skill(self, o_d):
        effort = 1
        handle = o_d["wndw_pos"] + np.array([-0.03, -0.03, -0.08])
        hx, hy, _hz = handle
        push_to = handle + np.array([0.3, 0, 0])
        self.skills = [
            MoveXSkill(hx, grab_effort=effort),
            MoveYSkill(hy, grab_effort=effort),
            PushSkill(push_to, grab_effort=effort),
        ]

    @staticmethod
    def _parse_obs(obs):
        """Name the fields of the window observation vector."""
        return dict(
            hand_pos=obs[:3],
            unused_gripper_open=obs[3],
            wndw_pos=obs[4:7],
            unused_info=obs[7:-3],
            goal_pos=obs[-3:],
        )


class SkillWindowOpenVariantV2Policy(SkillWindowOpenV2Policy):
    """Window-open policy offset by ``variant["handle_position"]``.

    The push gain is taken from ``variant["goal_resistance"]``.
    """

    def __init__(self, variant):
        super().__init__()
        # Variant settings; reads "handle_position" and "goal_resistance".
        self.variant = variant

    def _initialize_skill(self, o_d):
        effort = 1
        shifted = o_d["wndw_pos"] + np.array([-0.03, -0.02, -0.03])
        handle = shifted + np.array(self.variant["handle_position"])
        hx, hy, hz = handle

        gain = [2, 3.5, 5][self.variant["goal_resistance"]]
        push_to = handle + np.array([0.3, 0, 0])
        self.skills = [
            MoveXSkill(hx, grab_effort=effort),
            MoveYSkill(hy, grab_effort=effort),
            MoveZSkill(hz, grab_effort=effort),
            PushSkill(push_to, p=gain, grab_effort=effort),
        ]


############# Window Close V2 Policy #############


class SkillWindowCloseV2Policy(HeuristicPolicy):
    """Start 0.2 past the window target in x, then pull it back along -x."""

    def _initialize_skill(self, o_d):
        effort = 1
        handle = o_d["wndw_pos"] + np.array([0.03, -0.03, -0.08])
        hx, hy, _hz = handle
        pull_to = handle + np.array([-0.2, 0, 0])
        self.skills = [
            MoveXSkill(hx + 0.2, grab_effort=effort),
            MoveYSkill(hy, grab_effort=effort),
            PullSkill(pull_to, grab_effort=effort),
        ]

    @staticmethod
    def _parse_obs(obs):
        """Name the fields of the window observation vector."""
        return dict(
            hand_pos=obs[:3],
            unused_gripper_open=obs[3],
            wndw_pos=obs[4:7],
            unused_info=obs[7:-3],
            goal_pos=obs[-3:],
        )


class SkillWindowCloseVariantV2Policy(SkillWindowCloseV2Policy):
    """Window-close policy offset by ``variant["handle_position"]``.

    The pull gain is taken from ``variant["goal_resistance"]``.
    """

    def __init__(self, variant):
        super().__init__()
        # Variant settings; reads "handle_position" and "goal_resistance".
        self.variant = variant

    def _initialize_skill(self, o_d):
        effort = 1
        shifted = o_d["wndw_pos"] + np.array([0.03, -0.03, -0.08])
        handle = shifted + np.array(self.variant["handle_position"])
        hx, hy, hz = handle

        gain = [2, 3.5, 5][self.variant["goal_resistance"]]
        pull_to = handle + np.array([-0.2, 0, 0])
        self.skills = [
            MoveXSkill(hx + 0.2, grab_effort=effort, p=5),
            MoveYSkill(hy, grab_effort=effort),
            MoveZSkill(hz, grab_effort=effort),
            PullSkill(pull_to, p=gain, grab_effort=effort),
        ]


############# Faucet Open V2 Policy #############


class SkillFaucetOpenV2Policy(HeuristicPolicy):
    """Move next to the faucet, then pull it toward +x/+y."""

    def _initialize_skill(self, o_d):
        effort = 1
        faucet = o_d["faucet_pos"]
        fx, fy, fz = faucet
        pull_to = faucet + np.array([0.4, 0.2, 0])
        self.skills = [
            MoveXSkill(fx - 0.02, grab_effort=effort),
            MoveYSkill(fy, grab_effort=effort),
            MoveZSkill(fz, grab_effort=effort),
            PullSkill(pull_to, grab_effort=effort),
        ]

    @staticmethod
    def _parse_obs(obs):
        """Name the fields of the faucet observation vector."""
        return dict(
            hand_pos=obs[:3],
            unused_gripper=obs[3],
            faucet_pos=obs[4:7],
            unused_info=obs[7:-3],
            target_pos=obs[-3:],
        )


class SkillFaucetOpenVariantV2Policy(SkillFaucetOpenV2Policy):
    """Faucet-open policy whose pull gain follows ``variant["goal_resistance"]``."""

    def __init__(self, variant):
        super().__init__()
        # Variant settings; "goal_resistance" indexes a 3-entry gain list.
        self.variant = variant

    def _initialize_skill(self, o_d):
        effort = 1
        faucet = o_d["faucet_pos"]
        fx, fy, fz = faucet

        gain = [2, 3.5, 5][self.variant["goal_resistance"]]
        pull_to = faucet + np.array([0.15, 0.15, 0])
        plan = [
            MoveXSkill(fx - 0.08, grab_effort=effort),
            MoveYSkill(fy, grab_effort=effort),
            MoveZSkill(fz + 0.03, grab_effort=effort),
            PullSkill(pull_to, p=gain, grab_effort=effort),
        ]
        self.skills = plan
        # Unlike the other policies, this one also returns the plan.
        return plan


############# Push Back V2 Policy #############


class SkillPushBackV2Policy(HeuristicPolicy):
    """Grab the puck and carry it to ``goal_pos`` (x first, then y)."""

    def _initialize_skill(self, o_d):
        px, py, pz = o_d["puck_pos"] + np.array([-0.005, 0, 0])
        reach = [
            MoveXSkill(px, grab_effort=-1),
            MoveYSkill(py, grab_effort=-1),
            MoveZSkill(pz + 0.03, grab_effort=-1),
            GrabSkill(grab_effort=1),
        ]

        gx, gy, _ = o_d["goal_pos"]
        carry = [MoveXSkill(gx, grab_effort=1), MoveYSkill(gy, grab_effort=1)]

        self.skills = reach + carry

    @staticmethod
    def _parse_obs(obs):
        """Name the fields of the puck observation vector."""
        return dict(
            hand_pos=obs[:3],
            gripper_distance_apart=obs[3],
            puck_pos=obs[4:7],
            puck_rot=obs[7:11],
            goal_pos=obs[-3:],
            unused_info_curr_obs=obs[11:18],
            _prev_obs=obs[18:36],
        )


class SkillPushBackVariantV2Policy(SkillPushBackV2Policy):
    """Push-back variant with a tighter x tolerance (atol=0.001) on the approach."""

    def __init__(self, variant):
        super().__init__()
        # Stored variant settings; not read by this class's plan.
        self.variant = variant

    def _initialize_skill(self, o_d):
        px, py, pz = o_d["puck_pos"] + np.array([-0.005, 0, 0])
        reach = [
            MoveXSkill(px, grab_effort=-1, atol=0.001),
            MoveYSkill(py, grab_effort=-1),
            MoveZSkill(pz + 0.03, grab_effort=-1),
            GrabSkill(grab_effort=1),
        ]

        gx, gy, _ = o_d["goal_pos"]
        carry = [MoveXSkill(gx, grab_effort=1), MoveYSkill(gy, grab_effort=1)]

        self.skills = reach + carry


############# Plate Slide V2 Policy #############


class SkillPlateSlideV2Policy(HeuristicPolicy):
    """Get behind the plate (grab_effort -1), then slide it to x = shelf_x, y = 0.9."""

    def _initialize_skill(self, o_d):
        px, py, pz = o_d["puck_pos"] + np.array([0.0, -0.055, 0.03])
        shelf_x = o_d["shelf_x"]
        self.skills = [
            MoveXSkill(px, grab_effort=-1),
            MoveYSkill(py, grab_effort=-1),
            MoveZSkill(pz + 0.02, grab_effort=-1),
            MoveXSkill(shelf_x, grab_effort=-1),
            MoveYSkill(0.9, grab_effort=-1),
        ]

    @staticmethod
    def _parse_obs(obs):
        """Name the fields of the plate-slide observation vector."""
        return dict(
            hand_pos=obs[:3],
            unused_1=obs[3],
            puck_pos=obs[4:7],
            unused_2=obs[7:-3],
            shelf_x=obs[-3],
            unused_3=obs[-2:],
        )


class SkillPlateSlideVariantV2Policy(SkillPlateSlideV2Policy):
    """Plate-slide policy that also stores a variant config.

    The skill plan is inherited unchanged from ``SkillPlateSlideV2Policy``.
    """

    def __init__(self, variant):
        super().__init__()
        # Not read by any method of this class.
        self.variant = variant


############# Plate Slide Side V2 Policy #############


class SkillPlateSlideSideV2Policy(HeuristicPolicy):
    """Get behind the plate, then slide it to ``goal_pos`` (y first, then x)."""

    def _initialize_skill(self, o_d):
        px, py, pz = o_d["puck_pos"] + np.array([0.0, -0.055, 0.03])
        approach = [
            MoveXSkill(px, grab_effort=-1),
            MoveYSkill(py, grab_effort=-1),
            MoveZSkill(pz + 0.02, grab_effort=-1),
        ]

        gx, gy, _ = o_d["goal_pos"]
        slide = [MoveYSkill(gy, grab_effort=-1), MoveXSkill(gx, grab_effort=-1)]

        self.skills = approach + slide

    @staticmethod
    def _parse_obs(obs):
        """Name the fields of the plate-slide observation vector."""
        return dict(
            hand_pos=obs[:3],
            unused_1=obs[3],
            puck_pos=obs[4:7],
            unused_2=obs[7:-3],
            goal_pos=obs[-3:],
        )


class SkillPlateSlideSideVariantV2Policy(SkillPlateSlideSideV2Policy):
    """Plate-slide-side policy that also stores a variant config.

    The skill plan is inherited unchanged from ``SkillPlateSlideSideV2Policy``.
    """

    def __init__(self, variant):
        super().__init__()
        # Not read by any method of this class.
        self.variant = variant


class SkillPlateSlideBackVariantV2Policy(SkillPlateSlideSideV2Policy):
    """Plate-slide-back policy that stores a variant config.

    NOTE(review): this reuses the plate-slide-*side* plan (approach the plate,
    then move in y and then x to ``goal_pos``); confirm that plan is the
    intended one for the "back" task.
    """

    def __init__(self, variant):
        super().__init__()
        # Not read by any method of this class.
        self.variant = variant


class SkillPlateSlideBackSideVariantV2Policy(SkillPlateSlideSideV2Policy):
    """Plate-slide-back-side policy that stores a variant config.

    The skill plan is inherited unchanged from ``SkillPlateSlideSideV2Policy``.
    """

    def __init__(self, variant):
        super().__init__()
        # Not read by any method of this class.
        self.variant = variant


class SkillPegUnplugSideV2Policy(HeuristicPolicy):
    """Grab the peg from above, then drag it to the goal's x coordinate."""

    def _initialize_skill(self, o_d):
        grip_point = o_d["peg_pos"] + np.array([-0.02, 0.0, 0.035])
        gx, gy, gz = grip_point
        grasp = [
            # Approach at grip z + 0.06, align in x/y, then descend and grab.
            MoveZSkill(gz + 0.06, grab_effort=-1),
            MoveXSkill(gx, grab_effort=-1),
            MoveYSkill(gy, grab_effort=-1, atol=0.02),
            MoveZSkill(gz, grab_effort=-1),
            GrabSkill(grab_effort=1),
        ]

        goal_x, _, _ = o_d["goal_pos"]
        self.skills = grasp + [MoveXSkill(goal_x, grab_effort=1)]

    @staticmethod
    def _parse_obs(obs):
        """Name the fields of the peg-unplug observation vector."""
        return dict(
            hand_pos=obs[:3],
            unused_gripper=obs[3],
            peg_pos=obs[4:7],
            unused_info=obs[7:-3],
            goal_pos=obs[-3:],
        )


class SkillPegUnplugSideVariantV2Policy(SkillPegUnplugSideV2Policy):
    """Peg-unplug policy that also stores a variant config.

    The skill plan is inherited unchanged from ``SkillPegUnplugSideV2Policy``.
    """

    def __init__(self, variant):
        super().__init__()
        # Not read by any method of this class.
        self.variant = variant


class SkillPickPlaceV2Policy(HeuristicPolicy):
    """Grab the puck, lift to the goal height, then move in x and y to the goal."""

    def _initialize_skill(self, o_d):
        px, py, pz = o_d["puck_pos"] + np.array([-0.002, 0, 0])
        pick = [
            MoveXSkill(px, grab_effort=-1),
            MoveYSkill(py, grab_effort=-1),
            MoveZSkill(pz + 0.03, grab_effort=-1),
            GrabSkill(grab_effort=1),
        ]

        gx, gy, gz = o_d["goal_pos"]
        place = [
            MoveZSkill(gz, grab_effort=1),
            MoveXSkill(gx, grab_effort=1),
            MoveYSkill(gy, grab_effort=1),
        ]

        self.skills = pick + place

    @staticmethod
    def _parse_obs(obs):
        """Name the fields of the pick-place observation vector."""
        return dict(
            hand_pos=obs[:3],
            gripper_distance_apart=obs[3],
            puck_pos=obs[4:7],
            puck_rot=obs[7:11],
            goal_pos=obs[-3:],
            unused_info_curr_obs=obs[11:18],
            _prev_obs=obs[18:36],
        )


class SkillPickPlaceVariantV2Policy(SkillPickPlaceV2Policy):
    """Pick-place policy that also stores a variant config.

    The skill plan is inherited unchanged from ``SkillPickPlaceV2Policy``.
    """

    def __init__(self, variant):
        super().__init__()
        # Not read by any method of this class.
        self.variant = variant
